{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `keras-unet-collection.models` user guide\n",
    "\n",
    "This user guide requires `keras-unet-collection==0.1.9` or higher.\n",
    "\n",
    "## Content\n",
    "\n",
    "* [**U-net**](#U-net)\n",
    "* [**V-net**](#V-net)\n",
    "* [**Attention-Unet**](#Attention-Unet)\n",
    "* [**U-net++**](#U-net++)\n",
    "* [**UNET 3+**](#UNET-3+)\n",
    "* [**R2U-net**](#R2U-net)\n",
    "* [**ResUnet-a**](#ResUnet-a)\n",
    "* [**U^2-Net**](#U^2-Net)\n",
    "* [**TransUNET**](#TransUNET)\n",
    "* [**Swin-UNET**](#Swin-UNET)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TensorFlow 2.5.0; Keras 2.5.0\n"
     ]
    }
   ],
   "source": [
    "print('TensorFlow {}; Keras {}'.format(tf.__version__, keras.__version__))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Step 1: importing `models` from `keras_unet_collection`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras_unet_collection import models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Step 2: defining your hyper-parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Commonly used hyper-parameter options are listed as follows. Full details are available through the Python helper function:\n",
    "\n",
    "* `input_size`: a tuple or list that defines the shape of input tensors. \n",
    "    * `models.resunet_a_2d`, `models.transunet_2d`, and `models.swin_unet_2d` support int only, others also support `input_size=(None, None, 3)`.\n",
    "    * `activation='PReLU'` is not compatible with `input_size=(None, None, 3)`. \n",
    "\n",
    "\n",
    "* `filter_num`: a list that defines the number of convolutional filters per down- and up-sampling blocks.\n",
    "    * For `unet_2d`, `att_unet_2d`, `unet_plus_2d`, `r2_unet_2d`, depth $\\ge$ 2 is expected.\n",
    "    * For `resunet_a_2d` and `u2net_2d`, depth $\\ge$ 3 is expected.\n",
    "\n",
    "\n",
    "* `n_labels`: number of output targets, e.g., `n_labels=2` for binary classification.\n",
    "\n",
    "* `activation`: the activation function of hidden layers. Available choices are `'ReLU'`, `'LeakyReLU'`, `'PReLU'`, `'ELU'`, `'GELU'`, `'Snake'`.\n",
    "\n",
    "* `output_activation`: the activation function of the output layer. Recommended choices are `'Sigmoid'`, `'Softmax'`, `None` (linear), `'Snake'`.\n",
    "\n",
    "* `batch_norm`: if specified as True, all convolutional layers will be configured as stacks of \"Conv2D-BN-Activation\".\n",
    "\n",
    "* `stack_num_down`: number of convolutional layers per downsampling level.\n",
    "\n",
    "* `stack_num_up`: number of convolutional layers (after concatenation) per upsampling level. \n",
    "\n",
    "* `pool`: the configuration of downsampling (encoding) blocks.\n",
    "    * `pool=False`: downsampling with a convolutional layer (2-by-2 convolution kernels with 2 strides; optional batch normalization and activation). \n",
    "    * `pool=True` or `pool='max'` downsampling with a max-pooling layer.\n",
    "    * `pool='ave'` downsampling with a average-pooling layer.\n",
    "\n",
    "\n",
    "* `unpool`: the configuration of upsampling (decoding) blocks.\n",
    "    * `unpool=False`: upsampling with a transpose convolutional layer (2-by-2 convolution kernels with 2 strides; optional batch normalization and activation). \n",
    "    * `unpool=True` or `unpool='bilinear'` upsampling with bilinear interpolation.\n",
    "    * `unpool='nearest'` upsampling with reflective padding.\n",
    "\n",
    "\n",
    "* `name`: user-specified prefix of the configured layer and model. Use `keras.models.Model.summary` to identify the exact name of each layer."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Step 3: Configuring your model\n",
    "\n",
    "**Note**\n",
    "\n",
    "Configured models can be saved through `model.save(filepath, save_traces=True)`, but they may contain python objects that are not part of the `tensorflow.keras`. Thus when loading the model, it is preferred to load the weights only, and set/freeze them within a new configuration.\n",
    "\n",
    "```python\n",
    "e.g.\n",
    "\n",
    "weights = dummy_loader(model_old_path)\n",
    "model_new = swin_transformer_model(...)\n",
    "model_new.set_weights(weights)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## U-net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Example 1**: U-net for binary classification with:\n",
    "\n",
    "1. Five down- and upsampliung levels (or four downsampling levels and one bottom level).\n",
    "\n",
    "2. Two convolutional layers per downsampling level.\n",
    "\n",
    "3. One convolutional layer (after concatenation) per upsamling level.\n",
    "\n",
    "2. Gaussian Error Linear Unit (GELU) activcation, Softmax output activation, batch normalization.\n",
    "\n",
    "3. Downsampling through Maxpooling.\n",
    "\n",
    "4. Upsampling through reflective padding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.unet_2d((None, None, 3), [64, 128, 256, 512, 1024], n_labels=2,\n",
    "                      stack_num_down=2, stack_num_up=1,\n",
    "                      activation='GELU', output_activation='Softmax', \n",
    "                      batch_norm=True, pool='max', unpool='nearest', name='unet')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## V-net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Example 2**: Vnet (originally proposed for 3-d inputs, here modified for 2-d inputs) for binary classification with:\n",
    "\n",
    "1. Input size of (256, 256, 1); PReLU does not support input tensor with shapes of NoneType \n",
    "\n",
    "\n",
    "\n",
    "1. Five down- and upsampliung levels (or four downsampling levels and one bottom level).\n",
    "\n",
    "1. Number of stacked convolutional layers of the residual path increase with downsampling levels from one to three (symmetrically, decrease with upsampling levels).   \n",
    "    * `res_num_ini=1`\n",
    "    * `res_num_max=3`\n",
    "    \n",
    " \n",
    "2. PReLU activcation, Softmax output activation, batch normalization.\n",
    "\n",
    "3. Downsampling through stride convolutional layers.\n",
    "\n",
    "4. Upsampling through transpose convolutional layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.vnet_2d((256, 256, 1), filter_num=[16, 32, 64, 128, 256], n_labels=2,\n",
    "                      res_num_ini=1, res_num_max=3, \n",
    "                      activation='PReLU', output_activation='Softmax', \n",
    "                      batch_norm=True, pool=False, unpool=False, name='vnet')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Attention-Unet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Example 3**: attention-Unet for single target regression with:\n",
    "\n",
    "1. Four down- and upsampling levels.\n",
    "\n",
    "2. Two convolutional layers per downsampling level.\n",
    "\n",
    "3. Two convolutional layers (after concatenation) per upsampling level.\n",
    "\n",
    "2. ReLU activation, linear output activation (None), batch normalization.\n",
    "\n",
    "3. Additive attention, ReLU attention activation.\n",
    "        \n",
    "4. Downsampling through stride convolutional layers.\n",
    "\n",
    "5. Upsampling through bilinear interpolation.   \n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.att_unet_2d((None, None, 3), [64, 128, 256, 512], n_labels=1,\n",
    "                           stack_num_down=2, stack_num_up=2,\n",
    "                           activation='ReLU', atten_activation='ReLU', attention='add', output_activation=None, \n",
    "                           batch_norm=True, pool=False, unpool='bilinear', name='attunet')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## U-net++"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Example 4**: U-net++ for three-label classification with:\n",
    "\n",
    "1. Four down- and upsampling levels.\n",
    "\n",
    "2. Two convolutional layers per downsampling level.\n",
    "\n",
    "3. Two convolutional layers (after concatenation) per upsampling level.\n",
    "\n",
    "2. LeakyReLU activation, Softmax output activation, no batch normalization.\n",
    "        \n",
    "3. Downsampling through Maxpooling.\n",
    "\n",
    "4. Upsampling through transpose convolutional layers.\n",
    "\n",
    "5. Deep supervision."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------\n",
      "deep_supervision = True\n",
      "names of output tensors are listed as follows (\"sup0\" is the shallowest supervision layer;\n",
      "\"final\" is the final output layer):\n",
      "\n",
      "\txnet_output_sup0_activation\n",
      "\txnet_output_sup1_activation\n",
      "\txnet_output_sup2_activation\n",
      "\txnet_output_final_activation\n"
     ]
    }
   ],
   "source": [
    "model = models.unet_plus_2d((None, None, 3), [64, 128, 256, 512], n_labels=3,\n",
    "                            stack_num_down=2, stack_num_up=2,\n",
    "                            activation='LeakyReLU', output_activation='Softmax', \n",
    "                            batch_norm=False, pool='max', unpool=False, deep_supervision=True, name='xnet')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## UNET 3+"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Example 5**: UNet 3+ for binary classification with:\n",
    "\n",
    "1. Four down- and upsampling levels.\n",
    "\n",
    "2. Two convolutional layers per downsampling level.\n",
    "\n",
    "3. One convolutional layers (after concatenation) per upsampling level.\n",
    "\n",
    "2. ReLU activation, Sigmoid output activation, batch normalization.\n",
    "        \n",
    "3. Downsampling through Maxpooling.\n",
    "\n",
    "4. Upsampling through transpose convolutional layers.\n",
    "\n",
    "5. Deep supervision."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Automated hyper-parameter determination is applied with the following details:\n",
      "----------\n",
      "\tNumber of convolution filters after each full-scale skip connection: filter_num_skip = [64, 64, 64]\n",
      "\tNumber of channels of full-scale aggregated feature maps: filter_num_aggregate = 256\n",
      "----------\n",
      "deep_supervision = True\n",
      "names of output tensors are listed as follows (\"sup0\" is the shallowest supervision layer;\n",
      "\"final\" is the final output layer):\n",
      "\n",
      "\tunet3plus_output_sup0_activation\n",
      "\tunet3plus_output_sup1_activation\n",
      "\tunet3plus_output_sup2_activation\n",
      "\tunet3plus_output_final_activation\n"
     ]
    }
   ],
   "source": [
    "model = models.unet_3plus_2d((128, 128, 3), n_labels=2, filter_num_down=[64, 128, 256, 512], \n",
    "                             filter_num_skip='auto', filter_num_aggregate='auto', \n",
    "                             stack_num_down=2, stack_num_up=1, activation='ReLU', output_activation='Sigmoid',\n",
    "                             batch_norm=True, pool='max', unpool=False, deep_supervision=True, name='unet3plus')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* `filter_num_skip` and `filter_num_aggregate` can be specified explicitly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------\n",
      "deep_supervision = True\n",
      "names of output tensors are listed as follows (\"sup0\" is the shallowest supervision layer;\n",
      "\"final\" is the final output layer):\n",
      "\n",
      "\tunet3plus_output_sup0_activation\n",
      "\tunet3plus_output_sup1_activation\n",
      "\tunet3plus_output_sup2_activation\n",
      "\tunet3plus_output_final_activation\n"
     ]
    }
   ],
   "source": [
    "model = models.unet_3plus_2d((128, 128, 3), n_labels=2, filter_num_down=[64, 128, 256, 512], \n",
    "                             filter_num_skip=[64, 64, 64], filter_num_aggregate=256, \n",
    "                             stack_num_down=2, stack_num_up=1, activation='ReLU', output_activation='Sigmoid',\n",
    "                             batch_norm=True, pool='max', unpool=False, deep_supervision=True, name='unet3plus')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## R2U-net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Example 6**: R2U-net for binary classification with:\n",
    "\n",
    "1. Four down- and upsampling levels.\n",
    "\n",
    "2. Two recurrent convolutional layers with two iterations per down- and upsampling level.\n",
    "\n",
    "2. ReLU activation, Softmax output activation, no batch normalization.\n",
    "        \n",
    "3. Downsampling through Maxpooling.\n",
    "\n",
    "4. Upsampling through reflective padding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.r2_unet_2d((None, None, 3), [64, 128, 256, 512], n_labels=2,\n",
    "                          stack_num_down=2, stack_num_up=1, recur_num=2,\n",
    "                          activation='ReLU', output_activation='Softmax', \n",
    "                          batch_norm=True, pool='max', unpool='nearest', name='r2unet')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ResUnet-a"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Example 7**: ResUnet-a for 16-label classification with:\n",
    "\n",
    "1. input size of (128, 128, 3)\n",
    "\n",
    "1. Six downsampling levels followed by an Atrous Spatial Pyramid Pooling (ASPP) layer with 256 filters.\n",
    "\n",
    "1. Six upsampling levels followed by an ASPP layer with 128 filters.\n",
    "\n",
    "2. dilation rates of {1, 3, 15, 31} for shallow layers, {1,3,15} for intermediate layers, and {1,} for deep layers.\n",
    "\n",
    "3. ReLU activation, Sigmoid output activation, batch normalization.\n",
    "\n",
    "4. Downsampling through stride convolutional layers.\n",
    "\n",
    "4. Upsampling through reflective padding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Received dilation rates: [1, 3, 15, 31]\n",
      "Received dilation rates are not defined on a per downsampling level basis.\n",
      "Automated determinations are applied with the following details:\n",
      "\tdepth-0, dilation_rate = [1, 3, 15, 31]\n",
      "\tdepth-1, dilation_rate = [1, 3, 15, 31]\n",
      "\tdepth-2, dilation_rate = [1, 3, 15]\n",
      "\tdepth-3, dilation_rate = [1, 3, 15]\n",
      "\tdepth-4, dilation_rate = [1]\n",
      "\tdepth-5, dilation_rate = [1]\n"
     ]
    }
   ],
   "source": [
    "model = models.resunet_a_2d((128, 128, 3), [32, 64, 128, 256, 512, 1024], \n",
    "                            dilation_num=[1, 3, 15, 31], \n",
    "                            n_labels=16, aspp_num_down=256, aspp_num_up=128, \n",
    "                            activation='ReLU', output_activation='Sigmoid', \n",
    "                            batch_norm=True, pool=False, unpool='nearest', name='resunet')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* `dilation_num` can be specified per down- and uplampling level:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.resunet_a_2d((128, 128, 3), [32, 64, 128, 256, 512, 1024], \n",
    "                            dilation_num=[[1, 3, 15, 31], [1, 3, 15, 31], [1, 3, 15], [1, 3, 15], [1,], [1,],],\n",
    "                            n_labels=16, aspp_num_down=256, aspp_num_up=128, \n",
    "                            activation='ReLU', output_activation='Sigmoid', \n",
    "                            batch_norm=True, pool=False, unpool='nearest', name='resunet')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## U^2-Net"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Example 8**: U^2-Net for binary classification with:\n",
    "\n",
    "1. Six downsampling levels with the first four layers built with RSU, and the last two (one downsampling layer, one bottom layer) built with RSU-F4.\n",
    "    * `filter_num_down=[64, 128, 256, 512]`\n",
    "    * `filter_mid_num_down=[32, 32, 64, 128]`\n",
    "    * `filter_4f_num=[512, 512]`\n",
    "    * `filter_4f_mid_num=[256, 256]`\n",
    "    \n",
    "    \n",
    "1. Six upsampling levels with the deepest layer built with RSU-F4, and the other four layers built with RSU.\n",
    "    * `filter_num_up=[64, 64, 128, 256]`\n",
    "    * `filter_mid_num_up=[16, 32, 64, 128]`\n",
    "    \n",
    "    \n",
    "3. ReLU activation, Sigmoid output activation, batch normalization.\n",
    "\n",
    "4. Deep supervision\n",
    "\n",
    "5. Downsampling through stride convolutional layers.\n",
    "\n",
    "6. Upsampling through transpose convolutional layers.\n",
    "\n",
    "*In the original work of U^2-Net, down- and upsampling were achieved through maxpooling (`pool=True` or `pool='max'`) and bilinear interpolation (`unpool=True` or unpool=`'bilinear'`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "----------\n",
      "The depth of u2net_2d = len(filter_num_down) + len(filter_4f_num) = 6\n",
      "----------\n",
      "deep_supervision = True\n",
      "names of output tensors are listed as follows (\"sup0\" is the shallowest supervision layer;\n",
      "\"final\" is the final output layer):\n",
      "\n",
      "\tu2net_output_sup0_trans_conv\n",
      "\tu2net_output_sup1_trans_conv\n",
      "\tu2net_output_sup2_trans_conv\n",
      "\tu2net_output_sup3_trans_conv\n",
      "\tu2net_output_sup4_trans_conv\n",
      "\tu2net_output_sup5_trans_conv\n",
      "\tu2net_output_final\n"
     ]
    }
   ],
   "source": [
    "model = models.u2net_2d((128, 128, 3), n_labels=2, \n",
    "                        filter_num_down=[64, 128, 256, 512], filter_num_up=[64, 64, 128, 256], \n",
    "                        filter_mid_num_down=[32, 32, 64, 128], filter_mid_num_up=[16, 32, 64, 128], \n",
    "                        filter_4f_num=[512, 512], filter_4f_mid_num=[256, 256], \n",
    "                        activation='ReLU', output_activation=None, \n",
    "                        batch_norm=True, pool=False, unpool=False, deep_supervision=True, name='u2net')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "* `u2net_2d` supports automated determination of filter numbers per down- and upsampling level. Auto-mode may produce a slightly larger network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Automated hyper-parameter determination is applied with the following details:\n",
      "----------\n",
      "\tNumber of RSU output channels within downsampling blocks: filter_num_down = [64, 128, 256, 512]\n",
      "\tNumber of RSU intermediate channels within downsampling blocks: filter_mid_num_down = [16, 32, 64, 128]\n",
      "\tNumber of RSU output channels within upsampling blocks: filter_num_up = [64, 128, 256, 512]\n",
      "\tNumber of RSU intermediate channels within upsampling blocks: filter_mid_num_up = [16, 32, 64, 128]\n",
      "\tNumber of RSU-4F output channels within downsampling and bottom blocks: filter_4f_num = [512, 512]\n",
      "\tNumber of RSU-4F intermediate channels within downsampling and bottom blocks: filter_4f_num = [256, 256]\n",
      "----------\n",
      "Explicitly specifying keywords listed above if their \"auto\" settings do not satisfy your needs\n",
      "----------\n",
      "The depth of u2net_2d = len(filter_num_down) + len(filter_4f_num) = 6\n",
      "----------\n",
      "deep_supervision = True\n",
      "names of output tensors are listed as follows (\"sup0\" is the shallowest supervision layer;\n",
      "\"final\" is the final output layer):\n",
      "\n",
      "\tu2net_output_sup0_activation\n",
      "\tu2net_output_sup1_activation\n",
      "\tu2net_output_sup2_activation\n",
      "\tu2net_output_sup3_activation\n",
      "\tu2net_output_sup4_activation\n",
      "\tu2net_output_sup5_activation\n",
      "\tu2net_output_final_activation\n"
     ]
    }
   ],
   "source": [
    "model = models.u2net_2d((None, None, 3), n_labels=2, \n",
    "                        filter_num_down=[64, 128, 256, 512],\n",
    "                        activation='ReLU', output_activation='Sigmoid', \n",
    "                        batch_norm=True, pool=False, unpool=False, deep_supervision=True, name='u2net')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## TransUNET"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Example 9**: TransUNET for 12-label classification with:\n",
    "\n",
    "* input size of (512, 512, 3)\n",
    "\n",
    "* Four down- and upsampling levels.\n",
    "\n",
    "* Two convolutional layers per downsampling level.\n",
    "\n",
    "* Two convolutional layers (after concatenation) per upsampling level.\n",
    "\n",
    "* 12 transformer blocks (`num_transformer=12`).\n",
    "\n",
    "* 12 attention heads (`num_heads=12`).\n",
    "\n",
    "* 3072 MLP nodes per vision transformer (`num_mlp=3072`).\n",
    "\n",
    "* 768 embeding dimensions (`embed_dim=768`).\n",
    "\n",
    "* Gaussian Error Linear Unit (GELU) activcation for transformer MLPs.\n",
    "\n",
    "* ReLU activation, softmax output activation, batch normalization.\n",
    "\n",
    "* Downsampling through maxpooling.\n",
    "\n",
    "* Upsampling through bilinear interpolation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.transunet_2d((512, 512, 3), filter_num=[64, 128, 256, 512], n_labels=12, stack_num_down=2, stack_num_up=2,\n",
    "                                embed_dim=768, num_mlp=3072, num_heads=12, num_transformer=12,\n",
    "                                activation='ReLU', mlp_activation='GELU', output_activation='Softmax', \n",
    "                                batch_norm=True, pool=True, unpool='bilinear', name='transunet')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Swin-UNET"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Example 10**: Swin-UNET for 3-label classification with:\n",
    "\n",
    "* input size of (128, 128, 3)\n",
    "\n",
    "* Four down- and upsampling levels (or three downsampling levels and one bottom level) (`depth=4`).\n",
    "\n",
    "* Two Swin-Transformers per downsampling level.\n",
    "\n",
    "* Two Swin-Transformers (after concatenation) per upsampling level.\n",
    "\n",
    "* Extract 2-by-2 patches from the input (`patch_size=(2, 2)`)\n",
    "\n",
    "* Embed 2-by-2 patches to 64 dimensions (`filter_num_begin=64`, a.k.a, number of embedded dimensions).\n",
    "\n",
    "* Number of attention heads for each down- and upsampling level: `num_heads=[4, 8, 8, 8]`.\n",
    "\n",
    "* Size of attention windows for each down- and upsampling level: `window_size=[4, 2, 2, 2]`.\n",
    "\n",
    "* 512 nodes per Swin-Transformer (`num_mlp=512`)\n",
    "\n",
    "* Shift attention windows (i.e., Swin-MSA) (`shift_window=True`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.swin_unet_2d((128, 128, 3), filter_num_begin=64, n_labels=3, depth=4, stack_num_down=2, stack_num_up=2, \n",
    "                            patch_size=(2, 2), num_heads=[4, 8, 8, 8], window_size=[4, 2, 2, 2], num_mlp=512, \n",
    "                            output_activation='Softmax', shift_window=True, name='swin_unet')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
